from mindspore import nn
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
from mindspore.train.dataset_helper import DatasetHelper, connect_network_with_dataset
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore import context
import os
import numpy as np
from cells import SigmoidCrossEntropyWithLogits, GenWithLossCell, DisWithLossCell, TrainOneStepCell, Reshape
import matplotlib.pyplot as plt
import time

"""训练平均用时 9.44s ，显存消耗1344m"""
batch_size = 128
epochs = 30
input_dim = 100
lr = 0.0002

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


def save_imgs(gen_imgs, idx):
    """Save a 4x4 grid of generated images to ./images/{idx}.png.

    Args:
        gen_imgs (np.ndarray): generated images of shape (N, 1, H, W) with
            values in [-1, 1] (tanh output); at most 16 are plotted.
        idx (int): index used in the output filename (here, the epoch number).
    """
    # Create the output directory on first use; without this, savefig raises
    # FileNotFoundError when ./images does not already exist.
    os.makedirs("./images", exist_ok=True)
    for i in range(gen_imgs.shape[0]):
        plt.subplot(4, 4, i + 1)
        # Map [-1, 1] back to the [0, 255] grayscale range for display.
        plt.imshow(gen_imgs[i, 0, :, :] * 127.5 + 127.5, cmap="gray")
        plt.axis("off")
    plt.savefig("./images/{}.png".format(idx))
    # Close the figure so repeated calls (one per epoch) do not accumulate
    # axes and leak memory.
    plt.close()


def create_dataset(data_path,
                   latent_size,
                   batch_size,
                   repeat_size=1,
                   num_parallel_workers=1):
    """Build the MNIST input pipeline for GAN training.

    Each sample yields two columns: a CHW float32 image normalized to
    [-1, 1] and a float32 latent noise vector of length ``latent_size``.

    Args:
        data_path (str): directory containing the MNIST data files.
        latent_size (int): dimensionality of the sampled latent code.
        batch_size (int): samples per batch; remainder batches are dropped.
        repeat_size (int): number of dataset repeats. Default: 1.
        num_parallel_workers (int): workers for each map op. Default: 1.

    Returns:
        A shuffled, batched, repeated MindSpore dataset.
    """
    def normalize(img):
        # Scale uint8 pixels from [0, 255] into [-1, 1] to match tanh output.
        return ((img - 127.5) / 127.5).astype("float32")

    def attach_latent(img):
        # Pair every image with a freshly sampled standard-normal latent code.
        return img, np.random.normal(size=(latent_size)).astype("float32")

    dataset = ds.MnistDataset(data_path)
    dataset = dataset.map(input_columns="image",
                          operations=normalize,
                          num_parallel_workers=num_parallel_workers)
    dataset = dataset.map(input_columns="image",
                          operations=CV.HWC2CHW(),
                          num_parallel_workers=num_parallel_workers)
    dataset = dataset.map(input_columns="image",
                          operations=attach_latent,
                          output_columns=["image", "latent_code"],
                          column_order=["image", "latent_code"],
                          num_parallel_workers=num_parallel_workers)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset.repeat(repeat_size)


class Generator(nn.Cell):
    """DCGAN-style generator: maps a latent vector to a 1-channel image.

    The latent code is projected to a 256-channel 7x7 feature map and then
    upsampled by transposed convolutions; the final tanh keeps outputs in
    [-1, 1], matching the dataset normalization.
    """

    def __init__(self, latent_size, auto_prefix=True):
        super(Generator, self).__init__(auto_prefix=auto_prefix)
        layers = [
            # Project the latent vector and reshape it into a 7x7 feature map.
            nn.Dense(latent_size, 256 * 7 * 7, has_bias=False),
            Reshape((-1, 256, 7, 7)),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            # Stride-1 transposed conv: spatial size preserved.
            nn.Conv2dTranspose(256, 128, 5, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            # Stride-2 transposed conv: first spatial upsampling step.
            nn.Conv2dTranspose(128, 64, 5, 2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            # Second upsampling down to a single output channel, then squash
            # into [-1, 1].
            nn.Conv2dTranspose(64, 1, 5, 2),
            nn.Tanh(),
        ]
        self.network = nn.SequentialCell(layers)

    def construct(self, x):
        return self.network(x)


class Discriminator(nn.Cell):
    """DCGAN-style discriminator: scores a 1-channel image with one logit.

    Two strided conv blocks downsample the input, then a dense layer
    produces a single raw score; the sigmoid is applied by the loss
    (SigmoidCrossEntropyWithLogits), not here.
    """

    def __init__(self, auto_prefix=True):
        super().__init__(auto_prefix=auto_prefix)
        layers = [
            nn.Conv2d(1, 64, 5, 2),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.Conv2d(64, 128, 5, 2),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            # Flatten the 128 * 7 * 7 features into one unnormalized logit.
            nn.Flatten(),
            nn.Dense(128 * 7 * 7, 1),
        ]
        self.network = nn.SequentialCell(layers)

    def construct(self, x):
        return self.network(x)


# --- Build networks, loss cells, optimizers, and the combined train step. ---
netG = Generator(input_dim)
netD = Discriminator()
loss = SigmoidCrossEntropyWithLogits()
netG_with_loss = GenWithLossCell(netG, netD, loss)
netD_with_loss = DisWithLossCell(netG, netD, loss)
# beta1=0.5 is the customary DCGAN Adam setting.
optimizerG = nn.Adam(netG.trainable_params(), lr, beta1=0.5, beta2=0.999)
optimizerD = nn.Adam(netD.trainable_params(), lr, beta1=0.5, beta2=0.999)
net_train = TrainOneStepCell(netG_with_loss, netD_with_loss, optimizerG,
                             optimizerD)

# NOTE(fix): the dataset was originally bound to `ds`, shadowing the
# `mindspore.dataset as ds` module alias imported at the top of the file;
# use a distinct name so the module stays accessible.
train_ds = create_dataset(os.path.join("../data/MNIST_Data", "train"),
                          latent_size=input_dim,
                          batch_size=batch_size,
                          num_parallel_workers=2)
dataset_helper = DatasetHelper(train_ds, epoch_num=epochs,
                               dataset_sink_mode=True)
net_train = connect_network_with_dataset(net_train, dataset_helper)

netG.set_train()
netD.set_train()
# A fixed batch of 16 latent vectors so the saved image grids are directly
# comparable from epoch to epoch.
test_latent_code = Tensor(np.random.normal(size=(16, input_dim)),
                          dtype=mstype.float32)
for epoch in range(epochs):
    start = time.time()
    for data in dataset_helper:
        imgs = data[0]
        latent_code = data[1]
        # One combined step: updates D then G (see TrainOneStepCell).
        d_out, g_out = net_train(imgs, latent_code)
    t = time.time() - start
    print("time of epoch {} is {:.2f}s".format(epoch, t))
    # Snapshot the generator's progress on the fixed latent batch.
    gen_imgs = netG(test_latent_code)
    save_imgs(gen_imgs.asnumpy(), epoch)

